
import torch.nn as nn
import torchvision.models as models
import torch.nn.init as init

# Public API of this module: the three ResNet factory functions below.
__all__ = ["resnet18", "resnet34", "resnet50"]

def resnet18(num_classes=128, zero_init_residual=True) -> models.ResNet:
    """Build a torchvision ResNet-18 whose first conv is 3x3 with stride 1.

    The stock ``conv1`` is replaced by a 3-input-channel, 64-output-channel
    3x3 convolution with stride 1 and padding 1. That layer therefore
    preserves the spatial size of its input instead of downsampling it.

    Args:
        num_classes: Output size of the final ``fc`` layer (default 128).
        zero_init_residual: Forwarded unchanged to the torchvision
            constructor. Per torchvision's docs, it zero-initialises the last
            BatchNorm in each residual branch.

    Returns:
        The modified ``torchvision.models.ResNet`` instance.
    """
    model = models.resnet18(num_classes=num_classes, zero_init_residual=zero_init_residual)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    # torchvision applies Kaiming-normal (fan_out, relu) init to every Conv2d
    # inside ResNet.__init__. A layer swapped in afterwards misses that and
    # would keep PyTorch's default init, so re-apply the same scheme here.
    init.kaiming_normal_(model.conv1.weight, mode="fan_out", nonlinearity="relu")
    # NOTE(review): model.maxpool is left untouched. Many CIFAR-style variants
    # replace it with nn.Identity(); confirm keeping it is intended.
    return model

def resnet34(num_classes=128, zero_init_residual=True) -> models.ResNet:
    """Build a torchvision ResNet-34 whose first conv is 3x3 with stride 1.

    The stock ``conv1`` is replaced by a 3-input-channel, 64-output-channel
    3x3 convolution with stride 1 and padding 1. That layer therefore
    preserves the spatial size of its input instead of downsampling it.

    Args:
        num_classes: Output size of the final ``fc`` layer (default 128).
        zero_init_residual: Forwarded unchanged to the torchvision
            constructor. Per torchvision's docs, it zero-initialises the last
            BatchNorm in each residual branch.

    Returns:
        The modified ``torchvision.models.ResNet`` instance.
    """
    model = models.resnet34(num_classes=num_classes, zero_init_residual=zero_init_residual)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    # torchvision applies Kaiming-normal (fan_out, relu) init to every Conv2d
    # inside ResNet.__init__. A layer swapped in afterwards misses that and
    # would keep PyTorch's default init, so re-apply the same scheme here.
    init.kaiming_normal_(model.conv1.weight, mode="fan_out", nonlinearity="relu")
    # NOTE(review): model.maxpool is left untouched. Many CIFAR-style variants
    # replace it with nn.Identity(); confirm keeping it is intended.
    return model


def resnet50(num_classes=128, zero_init_residual=True) -> models.ResNet:
    """Build a torchvision ResNet-50 whose first conv is 3x3 with stride 1.

    The stock ``conv1`` is replaced by a 3-input-channel, 64-output-channel
    3x3 convolution with stride 1 and padding 1. That layer therefore
    preserves the spatial size of its input instead of downsampling it.

    Args:
        num_classes: Output size of the final ``fc`` layer (default 128).
        zero_init_residual: Forwarded unchanged to the torchvision
            constructor. Per torchvision's docs, it zero-initialises the last
            BatchNorm in each residual branch.

    Returns:
        The modified ``torchvision.models.ResNet`` instance.
    """
    model = models.resnet50(num_classes=num_classes, zero_init_residual=zero_init_residual)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    # torchvision applies Kaiming-normal (fan_out, relu) init to every Conv2d
    # inside ResNet.__init__. A layer swapped in afterwards misses that and
    # would keep PyTorch's default init, so re-apply the same scheme here.
    init.kaiming_normal_(model.conv1.weight, mode="fan_out", nonlinearity="relu")
    # NOTE(review): model.maxpool is left untouched. Many CIFAR-style variants
    # replace it with nn.Identity(); confirm keeping it is intended.
    return model
